{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tOoyQ70H00_s"
   },
   "source": [
    "## 练习2\n",
    "在课程中，我们已经学习了如何用Fashion MNIST数据集训练分类器。Fashion MNIST是一个包含各类服装图片的数据集。还有另一个类似的数据集叫MNIST，它包含很多手写数字（0到9）的图片和标签。\n",
    "\n",
    "编写一个MNIST分类器，让它可以训练到99%以上的准确率，并且在达到这个准确率时，就通过回调函数停止训练。\n",
    "\n",
    "一些注意事项：\n",
    "1. 模型应该在10个epochs以内达到目标，所以可以把epochs设为10，但不能超过10。\n",
    "2. 当它达到99%以上时，应该打印出 \"达到99%准确率，所以取消训练！\"的字符串。\n",
    "3. 如果你添加了任何额外的变量，请确保使用与课程中相同的变量名。\n",
    "\n",
    "我已经为你准备了下面的代码--请完成这个程序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28)\n",
      "(60000,)\n",
      "Train on 60000 samples\n",
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 6s 107us/sample - loss: 2.6857 - accuracy: 0.9062\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 5s 88us/sample - loss: 0.3357 - accuracy: 0.9365\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 5s 88us/sample - loss: 0.2947 - accuracy: 0.9420s - los - ETA: 0s - loss: 0.2911 - accuracy - ETA: 0s - loss: 0.290 - ETA: 0s - loss: 0.2951 - accuracy: 0.94 - ETA: 0s - loss: 0.2946 - accuracy: 0.\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 5s 87us/sample - loss: 0.2605 - accuracy: 0.9447\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 5s 88us/sample - loss: 0.2491 - accuracy: 0.9498\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.2197 - accuracy: 0.9530\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 5s 85us/sample - loss: 0.2053 - accuracy: 0.9562\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 5s 81us/sample - loss: 0.2082 - accuracy: 0.9589\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 5s 81us/sample - loss: 0.1863 - accuracy: 0.9622\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 5s 84us/sample - loss: 0.1867 - accuracy: 0.9629\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x1fbaff62148>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class myCallBack(tf.keras.callbacks.Callback):\n",
    "    \"\"\"Keras callback that ends training once accuracy exceeds 99%.\n",
    "\n",
    "    Setting ``self.model.stop_training`` to True makes ``fit`` end early.\n",
    "    \"\"\"\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Request an early stop when the epoch's accuracy is above 0.99.\n",
    "\n",
    "        Args:\n",
    "            epoch: Index of the epoch that just finished (not used here).\n",
    "            logs: Dict of metric values for the epoch; may be None.\n",
    "        \"\"\"\n",
    "        # 'accuracy' is absent when the model was not compiled with that\n",
    "        # metric; skip the check instead of comparing None with a float.\n",
    "        accuracy = (logs or {}).get('accuracy')\n",
    "        if accuracy is not None and accuracy > 0.99:\n",
    "            print('\\n模型准确率已达到99%，所以将停止')\n",
    "            self.model.stop_training = True\n",
    "\n",
    "\n",
    "# Instantiate the early-stopping callback\n",
    "callbacks = myCallBack()\n",
    "\n",
    "# Handwritten-digit dataset\n",
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "# Load the train / test split.\n",
    "# Pixel values are not rescaled in this cell.\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Inspect the data shapes\n",
    "print(x_train.shape)\n",
    "print(y_train.shape)\n",
    "\n",
    "# Define the model\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28, 28)),        # input layer: each image is 28x28 pixels\n",
    "    tf.keras.layers.Dense(512, activation=tf.nn.relu),    # hidden layer: 512 units with ReLU activation\n",
    "    tf.keras.layers.Dense(10, activation=tf.nn.softmax),  # output layer: 10 classes (digits 0-9) with softmax\n",
    "])\n",
    "\n",
    "# Configure the optimizer, loss function and reported metric\n",
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Train for at most 10 epochs; the callback may stop training earlier\n",
    "model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从上面的训练过程可以看到，模型收敛比较慢（训练满10个epoch后准确率只有约96%，没有达到99%），所以下面先把像素值整体归一化到0-1之间，再重新训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28)\n",
      "(60000,)\n",
      "Train on 60000 samples\n",
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 7s 110us/sample - loss: 0.2015 - accuracy: 0.9407 - loss: 0.2251 - ac - ETA: 1s - loss: 0.2171 - ac - ETA: 0s - loss: 0.2099 \n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 6s 93us/sample - loss: 0.0810 - accuracy: 0.9747\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 5s 90us/sample - loss: 0.0526 - accuracy: 0.9834\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 6s 92us/sample - loss: 0.0376 - accuracy: 0.9879s - loss:\n",
      "Epoch 5/10\n",
      "59616/60000 [============================>.] - ETA: 0s - loss: 0.0287 - accuracy: 0.9910\n",
      "模型准确率已达到99%，所以将停止\n",
      "60000/60000 [==============================] - 5s 91us/sample - loss: 0.0287 - accuracy: 0.9910\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x1fc1a9ee548>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class myCallBack(tf.keras.callbacks.Callback):\n",
    "    \"\"\"Keras callback that ends training once accuracy exceeds 99%.\n",
    "\n",
    "    Setting ``self.model.stop_training`` to True makes ``fit`` end early.\n",
    "    \"\"\"\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Request an early stop when the epoch's accuracy is above 0.99.\n",
    "\n",
    "        Args:\n",
    "            epoch: Index of the epoch that just finished (not used here).\n",
    "            logs: Dict of metric values for the epoch; may be None.\n",
    "        \"\"\"\n",
    "        # 'accuracy' is absent when the model was not compiled with that\n",
    "        # metric; skip the check instead of comparing None with a float.\n",
    "        accuracy = (logs or {}).get('accuracy')\n",
    "        if accuracy is not None and accuracy > 0.99:\n",
    "            print('\\n模型准确率已达到99%，所以将停止')\n",
    "            self.model.stop_training = True\n",
    "\n",
    "\n",
    "# Instantiate the early-stopping callback\n",
    "callbacks = myCallBack()\n",
    "\n",
    "# Handwritten-digit dataset\n",
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "# Load the train / test split\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Scale pixel values to the 0-1 range. This is done once, here; later cells\n",
    "# must use x_train / x_test as-is and not divide by 255 again.\n",
    "x_train, x_test = x_train/255.0, x_test/255.0\n",
    "\n",
    "# Inspect the data shapes\n",
    "print(x_train.shape)\n",
    "print(y_train.shape)\n",
    "\n",
    "# Define the model\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28, 28)),        # input layer: each image is 28x28 pixels\n",
    "    tf.keras.layers.Dense(512, activation=tf.nn.relu),    # hidden layer: 512 units with ReLU activation\n",
    "    tf.keras.layers.Dense(10, activation=tf.nn.softmax),  # output layer: 10 classes (digits 0-9) with softmax\n",
    "])\n",
    "\n",
    "# Configure the optimizer, loss function and reported metric\n",
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Train for at most 10 epochs; the callback may stop training earlier\n",
    "model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_7\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten_7 (Flatten)          (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense_14 (Dense)             (None, 512)               401920    \n",
      "_________________________________________________________________\n",
      "dense_15 (Dense)             (None, 10)                5130      \n",
      "=================================================================\n",
      "Total params: 407,050\n",
      "Trainable params: 407,050\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print each layer's output shape and parameter count\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "5130\n"
     ]
    }
   ],
   "source": [
    "# Explanation: re-derive the numbers shown by model.summary()\n",
    "\n",
    "print(784 == 28*28)  # Flatten output: every image is 28*28 pixels\n",
    "print(401920 == (784+1) * 512)  # hidden layer: (784 inputs + 1 bias) times 512 units\n",
    "print((512+1)*10)  # output layer: (512 hidden units + 1 bias) * 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "评估训练效果-整个模型进行评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 1s 94us/sample - loss: 0.0685 - accuracy: 0.9809\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.06853834754950949, 0.9809]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss and accuracy on the test set (x_test was scaled to 0-1 in the training cell)\n",
    "model.evaluate(x_test,y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如上accuracy为0.9809，比训练集的accuracy=0.9910稍差一些"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试训练结果-预测某个图片的类别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAANzUlEQVR4nO3df6zV9X3H8dcL5IdFVBiMMSRaLMRiF6G9oXV1m8a1s/xRbLK5ks5hY3O7rG5tQtIat6Q2/RGzVN2WNV1oJaWLP+L8UVlqOpHaOFuCXhwFhLZQhyvsChJuB24ZcK/v/XG/NFe93++5nPM9P+T9fCQ355zv+3y/33eOvvie8/2c7/k4IgTg7Dep2w0A6AzCDiRB2IEkCDuQBGEHkjinkzub6mkxXTM6uUsglf/T/+hknPB4tZbCbvs6SX8nabKkb0bEHVXPn64Zeq+vbWWXACpsjc2ltabfxtueLOlrkj4kaamk1baXNrs9AO3Vymf2FZL2RcSLEXFS0gOSVtXTFoC6tRL2BZJ+MebxgWLZ69jutz1ge+CUTrSwOwCtaPvZ+IhYFxF9EdE3RdPavTsAJVoJ+0FJC8c8vqhYBqAHtRL25yQttv1221MlfVTSxnraAlC3pofeImLY9i2S/lWjQ2/rI+KF2joDUKuWxtkj4nFJj9fUC4A24uuyQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dGfkkZz9n/pysr6yPTyyTnnXv5K5bpbrni4qZ5Ou/T7H6+sz3z23NLavL//UUv7xpnhyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gOGvru4sr5r2T+0bd+nyofoJ+Qn13yzsn5v3/zS2oObfq9y3ZE9e5vqCePjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gGNxtF/uOyBtu37H3+5qLJ+15YPVNYvubj6evgnlj5SWf/YzMHS2pdvmlO57qLPMc5ep5bCbnu/pOOSRiQNR0RfHU0BqF8dR/ZrIuJIDdsB0EZ8ZgeSaDXsIekJ29ts94/3BNv9tgdsD5zSiRZ3B6BZrb6NvyoiDtr+dUmbbP8kIp4e+4SIWCdpnSSd79ktXnYBoFktHdkj4mBxe1jSo5JW1NEUgPo1HXbbM2zPPH1f0gcl7aqrMQD1auVt/DxJj9o+vZ37IuJ7tXT1FjN87Xsq69+/4msNtjClsvq3Q0sq60/9ccWI538drlx3ydBAZX3S9OmV9a9s/a3K+m1zdpbWhmcNV66LejUd9oh4UdIVNfYCoI0YegOSIOxAEoQdSIKwA0kQdiAJLnGtwasLplbWJzX4N7XR0NoPPlw9vDXy4k8r663Y94XllfX7Zt/ZYAvTSisXfY9jTSfxagNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyz1+DCb2+prP/hwJ9U1j10rLI+PLj/TFuqzSdWPllZP29S+Tg6egtHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2DhjZ/bNut1Bq/5evrKzffOFXG2yh+qem1w6+r7Q288k9leuONNgzzgxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2s9wvb6weR//hn1aPo18wqXocfcuJyZX17V8q/935c489W7ku6tXwyG57ve3DtneNWTbb9ibbe4vbWe1tE0CrJvI2/luSrnvDslslbY6IxZI2F48B9LCGYY+IpyUdfcPiVZI2FPc3SLq+3rYA1K3Zz+zzImKwuP+ypHllT7TdL6lfkqbrbU3uDkCrWj4bHxEhKSrq6yKiLyL6plRM8gegvZoN+yHb8yWpuD1cX0sA2qHZsG+UtKa4v0bSY/W0A6BdGn5mt32/pKslzbF9QNLnJd0h6UHbN0t6SdIN7WwSzTvy7tJPWJIaj6M3suYHn6isL/kOY+m9omHYI2J1SenamnsB0EZ8XRZIgrADSRB2
IAnCDiRB2IEkuMT1LHBy08WltS2X3dlg7eqhtyu2rKmsv3Ptzyvr/Bx07+DIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM7+FnDOoksq6198xz+X1mY1uIR124nqfV/8xeqR8pGhoeoNoGdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnfwu49MGDlfXlU5v/N3v15j+rrC/58XNNbxu9hSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHsPGFpzZWX9C/Ma/fb7tNLKmv2/X7nmOz+7r7LO776fPRoe2W2vt33Y9q4xy263fdD29uJvZXvbBNCqibyN/5ak68ZZfndELCv+Hq+3LQB1axj2iHha0tEO9AKgjVo5QXeL7R3F2/xZZU+y3W97wPbAKTX4wTMAbdNs2L8u6VJJyyQNSio9gxQR6yKiLyL6plScSALQXk2FPSIORcRIRLwm6RuSVtTbFoC6NRV22/PHPPyIpF1lzwXQGxqOs9u+X9LVkubYPiDp85Kutr1MUkjaL+mT7Wvxre+cBb9ZWf+dv9xaWT9vUvMff7bsfkdlfckQ16tn0TDsEbF6nMX3tKEXAG3E12WBJAg7kARhB5Ig7EAShB1IgktcO2DPbQsr69/5jX9pafvX7Pyj0hqXsOI0juxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7B2w7cN3N3hGa7/gc8Gfv1ZaGx4aamnbOHtwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwucmndBaW3KyQUd7OTNRl45UlqLE9XTgXla9fcPJs+d01RPkjQy98LK+t61U5ve9kTEiEtrl/1Fg98gOHasqX1yZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwt896H13W6h1G//+3iTAI86cuj8ynVnzT1eWd/6nvua6qnXLf3rWyrriz67pantNjyy215o+ynbu22/YPvTxfLZtjfZ3lvczmqqAwAdMZG38cOS1kbEUknvk/Qp20sl3Sppc0QslrS5eAygRzUMe0QMRsTzxf3jkvZIWiBplaQNxdM2SLq+TT0CqMEZfWa3fYmk5ZK2SpoXEYNF6WVJ80rW6ZfUL0nT9bamGwXQmgmfjbd9nqSHJX0mIl73TfyICEkx3noRsS4i+iKib0qLP6wIoHkTCrvtKRoN+r0R8Uix+JDt+UV9vqTD7WkRQB0avo23bUn3SNoTEXeNKW2UtEbSHcXtY23p8CywavfHKuub3/VQhzrpvB8tv79r+/7fOFlaOxXlP789ESt33FRZ/+/tzV9+u+CZ4abXrTKRz+zvl3SjpJ22txfLbtNoyB+0fbOklyTd0JYOAdSiYdgj4hlJZVfaX1tvOwDaha/LAkkQdiAJwg4kQdiBJAg7kASXuHbAuX/wH5X1y79SfUljtPG/0szLjlbW23kZ6eX/9vHKevznjJa2v+ihV8uLz+5saduztLelejdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJDz6IzOdcb5nx3vNhXJAu2yNzToWR8e9SpUjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTRMOy2F9p+yvZu2y/Y/nSx/HbbB21vL/5Wtr9dAM2ayPQDw5LWRsTztmdK2mZ7U1G7OyK+2r72ANRlIvOzD0oaLO4ft71H0oJ2NwagXmf0md32JZKWS9paLLrF9g7b623PKlmn3/aA7YFTOtFatwCaNuGw2z5P0sOSPhMRxyR9XdKlkpZp9Mh/53jrRcS6iOiLiL4pmtZ6xwCaMqGw256i0aDfGxGPSFJEHIqIkYh4TdI3JK1oX5sAWjWRs/GWdI+kPRFx15jl88c87SOSdtXfHoC6TORs/Psl3Shpp+3txbLbJK22vUxSSNov6ZNt6A9ATSZyNv4ZSeP9DvXj9bcDoF34Bh2QBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHY
gSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0Tndma/IumlMYvmSDrSsQbOTK/21qt9SfTWrDp7uzgi5o5X6GjY37RzeyAi+rrWQIVe7a1X+5LorVmd6o238UAShB1IotthX9fl/Vfp1d56tS+J3prVkd66+pkdQOd0+8gOoEMIO5BEV8Ju+zrbP7W9z/at3eihjO39tncW01APdLmX9bYP2941Ztls25ts7y1ux51jr0u99cQ03hXTjHf1tev29Ocd/8xue7Kkn0n6gKQDkp6TtDoidne0kRK290vqi4iufwHD9u9KelXStyPiXcWyv5F0NCLuKP6hnBURn+uR3m6X9Gq3p/EuZiuaP3aacUnXS7pJXXztKvq6QR143bpxZF8haV9EvBgRJyU9IGlVF/roeRHxtKSjb1i8StKG4v4Gjf7P0nElvfWEiBiMiOeL+8clnZ5mvKuvXUVfHdGNsC+Q9Isxjw+ot+Z7D0lP2N5mu7/bzYxjXkQMFvdfljSvm82Mo+E03p30hmnGe+a1a2b681Zxgu7NroqId0v6kKRPFW9Xe1KMfgbrpbHTCU3j3SnjTDP+K9187Zqd/rxV3Qj7QUkLxzy+qFjWEyLiYHF7WNKj6r2pqA+dnkG3uD3c5X5+pZem8R5vmnH1wGvXzenPuxH25yQttv1221MlfVTSxi708Sa2ZxQnTmR7hqQPqvemot4oaU1xf42kx7rYy+v0yjTeZdOMq8uvXdenP4+Ijv9JWqnRM/I/l/RX3eihpK9Fkn5c/L3Q7d4k3a/Rt3WnNHpu42ZJvyZps6S9kp6UNLuHevsnSTsl7dBosOZ3qberNPoWfYek7cXfym6/dhV9deR14+uyQBKcoAOSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJP4fcKgKSCYRzXYAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.imshow(x_test[1])  # display the image\n",
    "print(y_test[1])  # its ground-truth label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.05575686, 0.06803183, 0.09612951, 0.06050541, 0.08134149,\n",
       "        0.2966615 , 0.10483654, 0.07256926, 0.0989917 , 0.06517592]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict a single image by wrapping it in nested lists; per the original\n",
    "# author's note this call form works in TensorFlow 2.1.0.\n",
    "# x_test was already scaled to the 0-1 range when the data was loaded, so it\n",
    "# must not be divided by 255 again before prediction.\n",
    "model.predict([[x_test[1]]])   # probability of each of the 10 digits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.05575686, 0.06803183, 0.09612951, 0.06050541, 0.08134149,\n",
       "        0.2966615 , 0.10483654, 0.07256926, 0.0989917 , 0.06517592]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict a single image, reshaped to carry an explicit batch dimension of 1.\n",
    "# NOTE(review): the original comment said this does not work in 2.3.0; it\n",
    "# presumably referred to the nested-list call in the previous cell - confirm.\n",
    "# x_test is already scaled to 0-1 at load time, so no further division by 255.\n",
    "model.predict(tf.reshape(x_test[1], (1, 28, 28)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# x_test was already scaled to the 0-1 range when the data was loaded,\n",
    "# so it is passed to the model as-is (no second division by 255).\n",
    "# The argmax index over the 10 class probabilities is the predicted digit itself.\n",
    "print(np.argmax(model.predict(tf.reshape(x_test[1], (1, 28, 28)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The argmax index is already the predicted digit, so it should not be used\n",
    "# to index y_train. Show the true label of the predicted test image instead.\n",
    "y_test[1]"
   ]
  },
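  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "存在疑问：预测得到的索引是5，如何通过这个索引找到最有可能的数字标签？\n",
    "\n",
    "解答：`np.argmax` 返回的索引本身就是预测的数字（输出层第 i 个神经元对应数字 i），不需要再到 `y_train` 里查找；`y_train[5]` 只是第6个训练样本的标签，它等于2只是巧合。应该把预测结果与同一张测试图片的真实标签 `y_test[1]` 比较。另外，`x_test` 在加载数据时已经除以255归一化过，预测时不能再除以255一次，否则输入几乎全为0，模型输出的概率接近均匀分布——上面保存的输出中预测为5（概率仅约0.30）而真实标签为2，正是这个原因。"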
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "23b34fee034e6c6559cf1b732e166f57ae608166b2226a83e90a4ecbc0876e39"
  },
  "kernelspec": {
   "display_name": "Python 3.7.13 ('learn_tensorflow')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
